{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using data [Data and variances]\n"
     ]
    }
   ],
   "source": [
    "# Generative Adversarial Networks (GAN) example in PyTorch.\n",
    "# See related blog post at https://medium.com/@devnag/generative-adversarial-networks-gans-in-50-lines-of-code-pytorch-e81b79659e3f#.sch4xgsa9\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    " \n",
    "# Data params\n",
    "data_mean = 4\n",
    "data_stddev = 1.25\n",
    " \n",
    "# Model params\n",
    "g_input_size = 1     # Random noise dimension coming into generator, per output vector\n",
    "g_hidden_size = 50   # Generator complexity\n",
    "g_output_size = 1    # size of generated output vector\n",
    "d_input_size = 100   # Minibatch size - cardinality of distributions\n",
    "d_hidden_size = 50   # Discriminator complexity\n",
    "d_output_size = 1    # Single dimension for 'real' vs. 'fake'\n",
    "minibatch_size = d_input_size\n",
    " \n",
    "d_learning_rate = 2e-4  # 2e-4\n",
    "g_learning_rate = 2e-4\n",
    "optim_betas = (0.9, 0.999)\n",
    "num_epochs = 30000\n",
    "print_interval = 200\n",
    "d_steps = 1  # 'k' steps in the original GAN paper. Can put the discriminator on higher training freq than generator\n",
    "g_steps = 1\n",
    " \n",
    "# ### Uncomment only one of these\n",
    "#(name, preprocess, d_input_func) = (\"Raw data\", lambda data: data, lambda x: x)\n",
    "(name, preprocess, d_input_func) = (\"Data and variances\", lambda data: decorate_with_diffs(data, 2.0), lambda x: x * 2)\n",
    " \n",
    "print(\"Using data [%s]\" % (name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ##### DATA: Target data and generator input data\n",
    " \n",
    "def get_distribution_sampler(mu, sigma):\n",
    "    return lambda n: torch.Tensor(np.random.normal(mu, sigma, (1, n)))  # Gaussian\n",
    " \n",
    "def get_generator_input_sampler():\n",
    "    return lambda m, n: torch.rand(m, n)  # Uniform-dist data into generator, _NOT_ Gaussian"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ##### MODELS: Generator model and discriminator model\n",
    " \n",
    "class Generator(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, output_size):\n",
    "        super(Generator, self).__init__()\n",
    "        self.map1 = nn.Linear(input_size, hidden_size)\n",
    "        self.map2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.map3 = nn.Linear(hidden_size, output_size)\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = F.elu(self.map1(x))\n",
    "        x = F.sigmoid(self.map2(x))\n",
    "        return self.map3(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Discriminator(nn.Module):\n",
    "    def __init__(self, input_size, hidden_size, output_size):\n",
    "        super(Discriminator, self).__init__()\n",
    "        self.map1 = nn.Linear(input_size, hidden_size)\n",
    "        self.map2 = nn.Linear(hidden_size, hidden_size)\n",
    "        self.map3 = nn.Linear(hidden_size, output_size)\n",
    " \n",
    "    def forward(self, x):\n",
    "        x = F.elu(self.map1(x))\n",
    "        x = F.elu(self.map2(x))\n",
    "        return F.sigmoid(self.map3(x))\n",
    " \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract(v):\n",
    "    return v.data.storage().tolist()\n",
    " \n",
    "def stats(d):\n",
    "    return [np.mean(d), np.std(d)]\n",
    " \n",
    "def decorate_with_diffs(data, exponent):\n",
    "    mean = torch.mean(data.data, 1)\n",
    "    mean_broadcast = torch.mul(torch.ones(data.size()), mean.tolist()[0][0])\n",
    "    diffs = torch.pow(data - Variable(mean_broadcast), exponent)\n",
    "    return torch.cat([data, diffs], 1)\n",
    " \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "'float' object is not subscriptable",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-6-121557ea1b2d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     14\u001b[0m         \u001b[0;31m#  1A: Train D on real\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     15\u001b[0m         \u001b[0md_real_data\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md_sampler\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md_input_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 16\u001b[0;31m         \u001b[0md_real_decision\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mD\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpreprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0md_real_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     17\u001b[0m \u001b[0;31m#         d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     18\u001b[0m \u001b[0;31m#         d_real_error.backward() # compute/store gradients, but don't change params\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-1-4a9233a901c2>\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m     31\u001b[0m \u001b[0;31m# ### Uncomment only one of these\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     32\u001b[0m \u001b[0;31m#(name, preprocess, d_input_func) = (\"Raw data\", lambda data: data, lambda x: x)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m \u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreprocess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0md_input_func\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m\"Data and variances\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mdecorate_with_diffs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2.0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     34\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     35\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Using data [%s]\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-5-6f340e739453>\u001b[0m in \u001b[0;36mdecorate_with_diffs\u001b[0;34m(data, exponent)\u001b[0m\n\u001b[1;32m      7\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdecorate_with_diffs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexponent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m     \u001b[0mmean\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmean\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 9\u001b[0;31m     \u001b[0mmean_broadcast\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mones\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmean\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtolist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     10\u001b[0m     \u001b[0mdiffs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmean_broadcast\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexponent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     11\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdiffs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mTypeError\u001b[0m: 'float' object is not subscriptable"
     ]
    }
   ],
   "source": [
    "d_sampler = get_distribution_sampler(data_mean, data_stddev)\n",
    "gi_sampler = get_generator_input_sampler()\n",
    "G = Generator(input_size=g_input_size, hidden_size=g_hidden_size, output_size=g_output_size)\n",
    "D = Discriminator(input_size=d_input_func(d_input_size), hidden_size=d_hidden_size, output_size=d_output_size)\n",
    "criterion = nn.BCELoss()  # Binary cross entropy: http://pytorch.org/docs/nn.html#bceloss\n",
    "d_optimizer = optim.Adam(D.parameters(), lr=d_learning_rate, betas=optim_betas)\n",
    "g_optimizer = optim.Adam(G.parameters(), lr=g_learning_rate, betas=optim_betas)\n",
    " \n",
    "for epoch in range(num_epochs):\n",
    "    for d_index in range(d_steps):\n",
    "        # 1. Train D on real+fake\n",
    "        D.zero_grad()\n",
    " \n",
    "        #  1A: Train D on real\n",
    "        d_real_data = Variable(d_sampler(d_input_size))\n",
    "        d_real_decision = D(preprocess(d_real_data))\n",
    "#         d_real_error = criterion(d_real_decision, Variable(torch.ones(1)))  # ones = true\n",
    "#         d_real_error.backward() # compute/store gradients, but don't change params\n",
    " \n",
    "#         #  1B: Train D on fake\n",
    "#         d_gen_input = Variable(gi_sampler(minibatch_size, g_input_size))\n",
    "#         d_fake_data = G(d_gen_input).detach()  # detach to avoid training G on these labels\n",
    "#         d_fake_decision = D(preprocess(d_fake_data.t()))\n",
    "#         d_fake_error = criterion(d_fake_decision, Variable(torch.zeros(1)))  # zeros = fake\n",
    "#         d_fake_error.backward()\n",
    "#         d_optimizer.step()     # Only optimizes D's parameters; changes based on stored gradients from backward()\n",
    " \n",
    "#     for g_index in range(g_steps):\n",
    "#         # 2. Train G on D's response (but DO NOT train D on these labels)\n",
    "#         G.zero_grad()\n",
    " \n",
    "#         gen_input = Variable(gi_sampler(minibatch_size, g_input_size))\n",
    "#         g_fake_data = G(gen_input)\n",
    "#         dg_fake_decision = D(preprocess(g_fake_data.t()))\n",
    "#         g_error = criterion(dg_fake_decision, Variable(torch.ones(1)))  # we want to fool, so pretend it's all genuine\n",
    " \n",
    "#         g_error.backward()\n",
    "#         g_optimizer.step()  # Only optimizes G's parameters\n",
    " \n",
    "#     if epoch % print_interval == 0:\n",
    "#         print(\"%s: D: %s/%s G: %s (Real: %s, Fake: %s) \" % (epoch,\n",
    "#                                                             extract(d_real_error)[0],\n",
    "#                                                             extract(d_fake_error)[0],\n",
    "#                                                             extract(g_error)[0],\n",
    "#                                                             stats(extract(d_real_data)),\n",
    "#                                                             stats(extract(d_fake_data))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
